#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/20 20:20 
@Description ：添加线性层(效果一般) — autoencoder variant with fully connected bottleneck layers (results were only mediocre)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    Encoder: two unpadded 3x3 convolutions (``in_channels -> 8 -> 16``), each
    shrinking every spatial side by 2, followed by a linear projection of the
    flattened 16-channel feature map down to ``latent_dim`` values.
    Decoder: the mirror image -- a linear layer back to the flattened feature
    map, then two 3x3 transposed convolutions (``16 -> 8 -> in_channels``)
    that grow every spatial side back by 2 each.

    The defaults (10 channels, 32x32 inputs, 64-d code) reproduce the original
    hard-coded architecture; submodule names are unchanged, so existing
    ``state_dict`` checkpoints still load.

    Args:
        in_channels: Channels of the input and of the reconstruction.
        input_size: Spatial size of the input, either an ``int`` for square
            inputs or a ``(height, width)`` pair. Each side must exceed 4.
        latent_dim: Width of the bottleneck code.

    Raises:
        ValueError: if ``input_size`` is too small for the two 3x3 convolutions.
    """

    def __init__(self, in_channels=10, input_size=32, latent_dim=64):
        super(Autoencoder, self).__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size
        # Each unpadded 3x3 convolution removes one pixel per border (2 per side).
        feat_h, feat_w = height - 4, width - 4
        if feat_h <= 0 or feat_w <= 0:
            raise ValueError(
                f"input_size must be larger than 4 on each side, got {tuple(input_size)}"
            )
        # Shape of the encoder feature map that the linear layers are sized for.
        self._feat_shape = (16, feat_h, feat_w)
        self._flat_features = 16 * feat_h * feat_w
        # encoder
        self.enc1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=3)
        self.enc2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
        self.enc3 = nn.Linear(in_features=self._flat_features, out_features=latent_dim)
        # decoder
        self.dec1 = nn.Linear(in_features=latent_dim, out_features=self._flat_features)
        self.dec2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3)
        self.dec3 = nn.ConvTranspose2d(in_channels=8, out_channels=in_channels, kernel_size=3)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction.

        Intermediate activations are stored on the module as ``self.x1`` ..
        ``self.x5`` (presumably so callers can inspect them after a pass);
        ``self.x3`` is the bottleneck code *before* its ReLU.
        NOTE(review): storing them keeps the autograd graph of the last pass
        alive until the next call.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)`` whose
                spatial size matches the ``input_size`` given at construction.

        Returns:
            For a batched input, a tensor with the same shape as ``x``; values
            are non-negative because of the final ReLU.

        Raises:
            ValueError: if the encoder feature map does not have the shape the
                linear layers were built for (``x`` has the wrong spatial size).
                Without this check ``view(-1, ...)`` would silently fold the
                surplus elements into the batch dimension.
        """
        self.x1 = F.relu(self.enc1(x))
        features = F.relu(self.enc2(self.x1))
        if tuple(features.shape[-3:]) != self._feat_shape:
            raise ValueError(
                f"expected encoder feature map of shape {self._feat_shape}, "
                f"got {tuple(features.shape[-3:])}; check the input's spatial size"
            )
        self.x2 = features.view(-1, self._flat_features)
        self.x3 = self.enc3(self.x2)
        # The ReLU on the code is applied here, so self.x3 stays pre-activation.
        self.x4 = F.relu(self.dec1(F.relu(self.x3))).reshape(-1, *self._feat_shape)
        self.x5 = F.relu(self.dec2(self.x4))
        x = F.relu(self.dec3(self.x5))
        return x
